#! /usr/bin/env python3

import urlearning.Record as Record

class Variable:
    """A discrete variable (node) of a network, with its parents and parameters.

    Parents are stored as integer indices that are resolved through
    ``self.network.get(index)``.  The parameter table built by
    ``updateParameterSize`` has one row per joint instantiation of the
    parents (row number given by ``getParentIndex``) and ``arity`` columns,
    one per value of this variable.
    """

    def __init__(self, network, name, index):
        self.network = network    # owning network; resolves parent indices
        self.name = name
        self.index = index        # this variable's position in the network
        self.parents = []         # parent indices (unsorted unless merged)
        self.arity = 0            # number of values; addValue keeps it in step
        self.values = []          # value labels, in value-index order
        self.metaInformation = {}
        self.parameters = []      # sized later by updateParameterSize()

    def getIndex(self):
        """Return this variable's index in the network."""
        return self.index

    def getName(self):
        """Return the variable's name."""
        return self.name

    def setArity(self, arity):
        """Set the number of values; ``values`` itself is not touched."""
        self.arity = arity

    def getArity(self):
        """Return the number of values this variable can take."""
        return self.arity

    def addValue(self, value):
        """Append a value label and increase the arity by one to match."""
        self.arity += 1
        self.values.append(value)

    def setValues(self, values):
        """Replace the list of value labels.

        NOTE: ``arity`` is not updated here; call ``setArity`` separately if
        the number of values changes.
        """
        self.values = values

    def getValues(self):
        """Return the list of value labels."""
        return self.values

    def getValue(self, index):
        """Return the value label at position ``index``."""
        return self.values[index]

    def getValueIndex(self, value):
        """Return the position of ``value`` in ``values``, or -1 if absent.

        The comparison is case-insensitive: both sides are lower-cased.
        """
        for i, candidate in enumerate(self.values):
            if candidate.lower() == value.lower():
                return i
        return -1

    def putMetaInformation(self, key, value):
        """Store ``value`` under ``key`` in this variable's meta-information."""
        self.metaInformation[key] = value

    def getMetaInformation(self, key, default=None):
        """Return the meta-information stored under ``key``.

        Args:
            key: the meta-information key to look up.
            default: returned when ``key`` has not been stored (None by default).
        """
        return self.metaInformation.get(key, default)

    def getMetaInformationKeys(self):
        """Return a view of all stored meta-information keys."""
        return self.metaInformation.keys()

    def getParents(self):
        """Return the list of parent indices."""
        return self.parents

    def setParents(self, parents):
        """Replace the list of parent indices."""
        self.parents = parents

    def mergeParents(self, parents):
        """Union ``parents`` (any iterable of indices) into the parent list.

        The result is deduplicated and sorted ascending.
        """
        self.parents = sorted(set(self.parents).union(parents))

    def getParentsString(self):
        """Return the parents' names joined by single spaces."""
        return " ".join(self.network.get(p).getName() for p in self.parents)

    def addParent(self, index):
        """Append a parent index (no sorting or deduplication here)."""
        self.parents.append(index)

    def addParentName(self, parent):
        """Add a parent by name, resolved via ``network.getVariableIndex``."""
        self.addParent(self.network.getVariableIndex(parent))

    def updateParameterSize(self):
        """Reallocate the parameter table as a zero-filled grid.

        Rows = product of the parents' arities (1 when there are no parents);
        columns = this variable's arity.  Each row is a distinct list so rows
        can be written independently.
        """
        rows = 1
        for p in self.parents:
            rows *= self.network.get(p).getArity()
        self.parameters = [[0] * self.arity for _ in range(rows)]

    def getParameters(self):
        """Return the parameter table (list of rows)."""
        return self.parameters

    def getParametersSize(self):
        """Return the number of rows (parent instantiations) in the table."""
        return len(self.parameters)

    def getFirstInstantiation(self):
        """Return a ``Record`` with every parent set to its first value."""
        ins = Record.Record(self.network.size())
        for p in self.parents:
            ins.set(p, self.network.get(p).getValue(0))
        return ins

    def getNextParentInstantiation(self, instantiation):
        """Advance ``instantiation`` in place to the next parent combination.

        Works like an odometer over the parents, with the LAST entry of
        ``self.parents`` changing fastest.  A parent already at its last
        value wraps back to its first value and the carry moves to the
        previous parent.

        Returns:
            The number of parents that wrapped before one could be
            incremented, or -1 when every parent wrapped (the enumeration is
            finished and ``instantiation`` is back at the first combination).
        """
        # NOTE(review): this order (last parent fastest) mirrors the original
        # Java implementation and is the reverse of getParentIndex, where the
        # FIRST parent is the least significant digit.  Starting from
        # getFirstInstantiation every combination is still visited exactly
        # once, but consecutive calls do not yield consecutive row indices.
        wrapped = 0
        for p in reversed(self.parents):
            parent = self.network.get(p)
            i = parent.getValueIndex(instantiation.get(p))
            if i < parent.getArity() - 1:
                instantiation.set(p, parent.getValue(i + 1))
                return wrapped
            instantiation.set(p, parent.getValue(0))
            wrapped += 1

        # Every parent wrapped: we were at the last parent instantiation.
        return -1

    def getParentIndex(self, instantiation):
        """Return the parameter-table row for the parents' values.

        Mixed-radix encoding: the first parent is the least significant digit
        and each digit's base is that parent's arity.
        """
        # TODO: this assumes the instantiation is indexed by the same variable
        # indices as the network; it should use the data header in case the
        # indices do not match.
        # NOTE(review): a value not found by getValueIndex contributes -1 and
        # yields a wrong (possibly negative) row index.
        total = 1
        row = 0
        for p in self.parents:
            parent = self.network.get(p)
            row += parent.getValueIndex(instantiation.get(p)) * total
            total *= parent.getArity()
        return row

    def getParameter(self, pIndex, k):
        """Return the parameter for parent row ``pIndex`` and value ``k``."""
        return self.parameters[pIndex][k]

    def setParameter(self, pIndex, k, value):
        """Set the parameter for parent row ``pIndex`` and value ``k``."""
        self.parameters[pIndex][k] = value

    def print(self):
        """Print the variable's name and its parents' names, tab-indented.

        The base indentation depth equals the number of parents.
        """
        depth = len(self.parents)
        printTabs(depth)
        print("variable: {}".format(self.name))

        printTabs(depth+1)
        print("parents:")
        for p in self.parents:
            printTabs(depth+2)
            print(self.network.get(p).name)
        
def printTabs(tabs):
    """Write ``tabs`` tab characters to stdout with no trailing newline.

    Zero or negative counts write nothing.
    """
    indent = "\t" * tabs
    print(indent, end='')
